#!/usr/bin/env python

# SNP_info_retriever.py
#
# USAGE: python SNP_info_retriever.py infile.tsv [-1 -2 -3]
#
# To skip Step 1, add -1 to the command.  To skip Step 2, add -2 to the command.  To skip Step 3, add -3 to the command.
#
# This Python program takes a file with SNP IDs (rs######) and produces the location, retrieved from
# the UCSC Genome Browser using CruzDB.  For more info on CruzDB, see here: https://github.com/brentp/cruzdb
#
# Input structure: This takes a tab-separated values list of rsIDs (SNP IDs), such as:
# rsID		major	minor	X	Y	Z	p-value	NaN
# rs1113396	T	C	0.058	-0.014	0.0063	0.026	230497
#
# Note that the SNP ID should be the first column.  The column containing the p-value can be set below this line.
# In this example shown above, the p-value is in column 7, so the label below is 6 (Python arrays start at 0).
#
###### CUSTOMIZABLE: COLUMN CONTAINING P-VALUE FOR EACH SNP ######
#
pval_col = 6
#
###### END CUSTOMIZATION ######

import cruzdb, sys, time, os, math

# Connect to the UCSC Genome Browser database via cruzdb and bind the hg19
# assembly's snp147 table (dbSNP build 147) for the rsID lookups in step 1.
hg19 = cruzdb.Genome('hg19')
snp147 = hg19.snp147
# NOTE(review): time.clock() measures CPU time on Unix under Python 2 (and was
# removed in Python 3.8) - elapsed times printed below may understate real
# wall-clock waiting on network I/O; consider time.time() instead.
t9 = time.clock()  # start-of-run timestamp used for the step 1 timing report

if "-1" in sys.argv:
	print("Skipping step 1.")
else:
	# Step 1: look up the chromosomal location of each rsID in UCSC snp147
	# and write a .located.txt file with one annotated row per SNP.
	try:
		in_name = sys.argv[1]
	except IndexError:
		sys.exit("Warning - proper input not supplied on ARGV.\nUSAGE: python SNP_info_retriever.py infile.tsv [-1 -2 -3]")

	try:
		infile = open(in_name, "r")
	except IOError:
		# fail with a readable message instead of a raw traceback
		sys.exit("Warning - could not open input file: " + in_name)
	outfile = open(in_name[:-4] + ".located.txt", "w")

	try:
		# header row for the .located.txt output
		outfile.write("MarkerName\tChromosome\tLocation\tAllele1\tAllele2\tFreq.Allele1.HapMapCEU\tp\n")

		line_count = 0       # every input line seen
		processed_count = 0  # rs rows that passed the p-value filter
		error_count = 0      # rsIDs not found in snp147

		print("Now getting chromosomal locations for each SNP.")

		for line in infile:
			line_count += 1
			splitline = line.split("\t")
			if splitline[0][:2] != "rs":
				continue  # header or malformed row
			if float(splitline[pval_col].strip()) > 0.95:
				continue  # drop SNPs whose p-value is essentially 1
			processed_count += 1
			info = snp147.filter_by(name=splitline[0]).first()
			# NOTE(review): this relies on the row's str() form being
			# tab-separated with chromosome and position as the first two
			# fields - confirm against cruzdb's snp147 row stringification.
			split_info = str(info).split("\t")
			try:
				outfile.write("\t".join([splitline[0], split_info[0], split_info[1], splitline[1], splitline[2], splitline[3], splitline[pval_col]]) + "\n")
			except IndexError:
				# a SNP missing from snp147 stringifies to 'None', which has
				# no second tab-field; count it and move on
				if str(info) == 'None':
					error_count += 1
			if processed_count % 100 == 0:
				print(str(line_count) + " lines processed so far.")
				print("Most recent: " + "\t".join([str(line_count), str(processed_count), str(error_count), str(info)]))

		print("Locations gathered.")
		print("Number of SNPs not found in cruzdb: " + str(error_count))
	finally:
		# close handles even if a lookup raises mid-run (was leaked before)
		infile.close()
		outfile.close()

# Step 1 finished; next we identify the gene(s), if any, located at each
# SNP position, again using cruzdb.
t0 = time.clock()
print("Time elapsed for finding SNP locations: %s seconds." % (t0 - t9))

if "-2" in sys.argv:
	print("Skipping step 2.")
else:
	# Step 2: annotate each located SNP with the refGene gene name(s)
	# overlapping its position, reading the .located.txt file from step 1.
	infile = open(sys.argv[1][:-4] + ".located.txt", "r")
	outfile = open(sys.argv[1][:-4] + ".gene_loc.txt", "w")

	try:
		# header (same columns as step 1; a gene column is prepended per row)
		outfile.write("MarkerName\tChromosome\tLocation\tAllele1\tAllele2\tFreq.Allele1.HapMapCEU\tp\n")

		line_count = 0
		processed_count = 0
		error_count = 0

		print("Now getting gene info for each SNP.")

		for line in infile:
			line_count += 1
			splitline = line.split("\t")
			if splitline[0][:2] == "rs":
				if float(splitline[6].strip()) > 0.95:
					continue  # safety net; step 1 already filtered these
				processed_count += 1
				chrom = splitline[1]  # reuse splitline instead of re-splitting
				if chrom == "No_info":  # no chromosome location available
					continue
				start = int(splitline[2])
				end = start + 1  # query a 1-bp interval at the SNP position
				# build the gene-name set ONCE (the original iterated the
				# bin_query result twice, re-running the query)
				gene_names = set(g.name2 for g in hg19.bin_query('refGene', chrom, start, end))
				if gene_names:
					outfile.write("\t".join(["|".join(gene_names)] + splitline))
				else:
					outfile.write("INTERGENIC\t" + line)
			# progress report every 1000 lines
			if line_count % 1000 == 0:
				# fixed: was print("...", str(n)), which prints a tuple
				# under Python 2, unlike every other print in this file
				print("Lines processed:\t" + str(line_count))

		print("Gene information added for each SNP.")
	finally:
		infile.close()
		outfile.close()

# timestamp the end of step 2 and report its duration
t1 = time.clock()
print("Time elapsed for adding gene information: %s seconds." % (t1 - t0))

if "-3" in sys.argv:
	print("Skipping step 3.")
else:
	# finally, we convert to JSON structure.
	print("Now converting to JSON structured format for input into BigTop.")

	# describing chromosome lengths and total
	# NOTE(review): these per-chromosome lengths match GRCh38, but the SNP
	# coordinates were retrieved from hg19 (GRCh37) in step 1 - confirm the
	# intended assembly, otherwise the angular positions drift slightly.
	chr_names_list = ['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr21', 'chr22', 'chrX', 'chrY']
	chr_lengths_list = [248956422, 242193529, 198295559, 190214555, 181538259, 170805979, 159345973, 145138636, 138394717, 133797422, 135086622, 133275309, 114364328, 107043718, 101991189, 90338345, 83257441, 80373285, 58617616, 64444167, 46709983, 50818468, 156040895, 57227415]

	# total genome length; used to map a linear genome offset onto a circle
	total = sum(chr_lengths_list)

	# functions
	def convert_to_polar(line):
		"""Map one gene-annotated SNP row to polar coordinates (r, sigma, y).

		The row is tab-separated: gene, rsID, chromosome, position, allele1,
		allele2, allele frequency, p-value.  The radius is scaled from the
		allele frequency, the angle from the SNP's genome-wide linear offset,
		and the raw p-value is returned as the vertical coordinate.
		"""
		fields = line.strip().split("\t")
		chrom = fields[2]
		pos = fields[3]
		freq = fields[6]
		p = fields[7]

		# genome-wide linear offset: position within the chromosome plus the
		# summed lengths of every chromosome that precedes it
		offset = int(pos)
		if chrom in chr_names_list:
			idx = chr_names_list.index(chrom)
			offset += sum(chr_lengths_list[:idx])
		# reduce to an angle on the circle, rotated back by a quarter turn
		sigma = float(offset) / total * 2 * math.pi - (math.pi / 2)
		# rescale the allele frequency [0, 1] onto a radius in [100, 1000]
		r = (float(freq) * 900) + 100
		return r, sigma, float(p)

	def convert_to_cartesian(r, sigma, y_polar):
		"""Project polar (r, sigma) onto the x/z plane; y is -log10 of the p-value."""
		return (
			r * math.cos(sigma),
			-math.log10(y_polar),
			r * math.sin(sigma),
		)

	# Read the gene-annotated rows and emit one JSON object per plottable SNP.
	infile = open(sys.argv[1][:-4] + ".gene_loc.txt", "r")
	outfile = open(sys.argv[1][:-4] + ".coords.json", "w")

	try:
		outfile.write("[\n")
		first_record = True  # tracks whether a comma separator is needed

		for line in infile:
			try:
				split = line.strip().split("\t")
				if split[0] == "MarkerName":
					continue  # header row
				if split[1] == "No_info":
					continue  # no chromosomal location; cannot be plotted
				r, sigma, y_polar = convert_to_polar(line)
				x, y, z = convert_to_cartesian(r, sigma, y_polar)

				# Write the separator BEFORE every record after the first, so
				# the array carries no trailing comma (the original emitted
				# "},\n]" which is invalid JSON and breaks strict parsers).
				if not first_record:
					outfile.write(",\n")
				first_record = False
				outfile.write(
					"\t{\n" +
					"\t\t\"id\": \"" + split[1] + "\",\n" +
					"\t\t\"gene\": \"" + split[0] + "\",\n" +
					"\t\t\"coords\": [" + ",".join([str(x), str(y), str(z)]) + "],\n" +
					"\t\t\"chr\": \"" + split[2] + "\",\n" +
					"\t\t\"location\": " + split[3] + ",\n" +
					"\t\t\"frequency\": " + split[6] + ",\n" +
					"\t\t\"p\": " + split[7] + "\n\t}"
				)
			except IndexError:
				# print() form works identically here under Python 2 (single
				# argument) and keeps the file consistent with every other
				# print; the bare "print line" statements were Python-2-only
				print(line)
				print("Exited with error on above line.")
				sys.exit()
		outfile.write("\n]\n")
	finally:
		# close handles even on the sys.exit() error path
		infile.close()
		outfile.close()

	print("Successfully converted to JSON structure.")

	t2 = time.clock()
	print("Time elapsed for JSON conversion: " + str(t2-t1) + " seconds.")

print("All done!")  # end of pipeline (whichever steps were not skipped)
